package cn.camel.algorithm.tree.bst;

import java.util.ArrayList;
import java.util.List;

/**
 * @author: Camel
 * @description: 二叉搜索树实现类
 * @date: 2025/10/15
 */
public class BinarySearchTree {
    /** 根节点 */
    private TreeNode root;

    public BinarySearchTree() {
        this.root = null;
    }

    /**
     * 在二叉搜索树中查找指定值的节点
     * @param val 要查找的值
     * @return 如果找到返回对应的节点，否则返回null
     */
    public TreeNode search(int val) {
        return search(root, val);
    }

    /**
     * 递归查找指定值的节点
     * @param node 当前节点
     * @param val 要查找的值
     * @return 如果找到返回对应的节点，否则返回null
     */
    private TreeNode search(TreeNode node, int val) {
        if (node == null || node.getVal() == val) {
            return node;
        }
        
        if (val < node.getVal()) {
            return search(node.getLeft(), val);
        } else {
            return search(node.getRight(), val);
        }
    }

    /**
     * 迭代查找
     * @param val 要查找的值
     * @return 如果找到返回对应的节点，否则返回null
     */
    public TreeNode searchIterative(int val) {
        TreeNode current = root;
        while (current != null) {
            if (current.getVal() == val) { return current; }
            else if (val < current.getVal()) { current = current.getLeft(); }
            else { current = current.getRight(); }
        }
        return null;
    }

    /**
     * 插入新节点到二叉搜索树
     * @param val 要插入的值
     * @return 插入后的根节点
     */
    public TreeNode insert(int val) {
        root = insert(root, val);
        return root;
    }

    /**
     * 递归插入新节点
     * @param node 当前节点
     * @param val 要插入的值
     * @return 插入后的节点
     */
    private TreeNode insert(TreeNode node, int val) {
        // 如果当前节点为空，创建新节点
        if (node == null) {
            return new TreeNode(val);
        }

        // 根据二叉搜索树的性质递归插入
        if (val < node.getVal()) {
            TreeNode leftNode = insert(node.getLeft(), val);
            node.setLeft(leftNode);
        } else if (val > node.getVal()) {
            TreeNode rightNode = insert(node.getRight(), val);
            node.setRight(rightNode);
        }
        // 如果值已存在，不做任何操作（也可以根据需求决定是否更新）

        return node;
    }

    /**
     * 从二叉搜索树中删除指定值的节点
     * @param val 要删除的值
     * @return 删除后的根节点
     */
    public TreeNode delete(int val) {
        root = delete(root, val);
        return root;
    }

    /**
     * 递归删除指定值的节点
     * @param node 当前节点
     * @param val 要删除的值
     * @return 删除后的节点
     */
    private TreeNode delete(TreeNode node, int val) {
        // 节点为空，返回null
        if (node == null) {
            return null;
        }

        // 递归查找要删除的节点
        if (val < node.getVal()) {
            TreeNode leftNode = delete(node.getLeft(), val);
            node.setLeft(leftNode);
        } else if (val > node.getVal()) {
            TreeNode rightNode = delete(node.getRight(), val);
            node.setRight(rightNode);
        } else {
            // 找到要删除的节点
            
            // 情况1：叶子节点
            if (node.getLeft() == null && node.getRight() == null) {
                return null;
            }
            // 情况2：只有一个子节点
            else if (node.getLeft() == null) {
                return node.getRight();
            }
            else if (node.getRight() == null) {
                return node.getLeft();
            }
            // 情况3：有两个子节点
            else {
                // 找到右子树中的最小值（或左子树中的最大值）
                int minVal = findMinValue(node.getRight());
                // 用最小值替换当前节点的值
                node.setVal(minVal);
                // 删除右子树中的最小值节点
                node.setRight(delete(node.getRight(), minVal));
            }
        }
        return node;
    }

    /**
     * 查找以指定节点为根的子树中的最小值
     * @param node 子树根节点
     * @return 最小值
     */
    private int findMinValue(TreeNode node) {
        int minVal = node.getVal();
        while (node.getLeft() != null) {
            minVal = node.getLeft().getVal();
            node = node.getLeft();
        }
        return minVal;
    }

    /**
     * 查找以指定节点为根的子树中的最大值
     * @param node 子树根节点
     * @return 最大值
     */
    private int findMaxValue(TreeNode node) {
        int maxVal = node.getVal();
        while (node.getRight() != null) {
            maxVal = node.getRight().getVal();
            node = node.getRight();
        }
        return maxVal;
    }

    /**
     * 查找指定值的前驱节点
     * 前驱节点是中序遍历中在该节点之前的节点（即小于该节点的最大节点）
     * @param val 要查找前驱的值
     * @return 前驱节点，如果不存在返回null
     */
    public TreeNode predecessor(int val) {
        TreeNode node = search(val);
        if (node == null) {
            return null;
        }
        return predecessor(node);
    }

    /**
     * 查找指定节点的前驱节点
     * @param node 要查找前驱的节点
     * @return 前驱节点，如果不存在返回null
     */
    private TreeNode predecessor(TreeNode node) {
        // 如果左子树不为空，前驱是左子树中的最大值
        if (node.getLeft() != null) {
            TreeNode temp = node.getLeft();
            while (temp.getRight() != null) {
                temp = temp.getRight();
            }
            return temp;
        }

        // 否则，向上遍历直到找到一个祖先节点，该节点的右子节点是当前路径上的节点
        TreeNode parent = node.getParent();
        TreeNode current = node;
        while (parent != null && current == parent.getLeft()) {
            current = parent;
            parent = parent.getParent();
        }
        return parent;
    }

    /**
     * 查找指定值的后继节点
     * 后继节点是中序遍历中在该节点之后的节点（即大于该节点的最小节点）
     * @param val 要查找后继的值
     * @return 后继节点，如果不存在返回null
     */
    public TreeNode successor(int val) {
        TreeNode node = search(val);
        if (node == null) {
            return null;
        }
        return successor(node);
    }

    /**
     * 查找指定节点的后继节点
     * @param node 要查找后继的节点
     * @return 后继节点，如果不存在返回null
     */
    private TreeNode successor(TreeNode node) {
        // 如果右子树不为空，后继是右子树中的最小值
        if (node.getRight() != null) {
            TreeNode temp = node.getRight();
            while (temp.getLeft() != null) {
                temp = temp.getLeft();
            }
            return temp;
        }

        // 否则，向上遍历直到找到一个祖先节点，该节点的左子节点是当前路径上的节点
        TreeNode parent = node.getParent();
        TreeNode current = node;
        while (parent != null && current == parent.getRight()) {
            current = parent;
            parent = parent.getParent();
        }
        return parent;
    }

    /**
     * 范围查找，获取树中所有在[min, max]范围内的值
     * @param min 最小值（包含）
     * @param max 最大值（包含）
     * @return 范围内所有值的列表
     */
    public List<Integer> rangeSearch(int min, int max) {
        List<Integer> result = new ArrayList<>();
        rangeSearch(root, min, max, result);
        return result;
    }

    /**
     * 递归执行范围查找
     * @param node 当前节点
     * @param min 最小值（包含）
     * @param max 最大值（包含）
     * @param result 结果列表
     */
    private void rangeSearch(TreeNode node, int min, int max, List<Integer> result) {
        if (node == null) {
            return;
        }

        // 当前节点值大于最小值，需要搜索左子树
        if (node.getVal() > min) {
            rangeSearch(node.getLeft(), min, max, result);
        }

        // 当前节点值在范围内，加入结果列表
        if (node.getVal() >= min && node.getVal() <= max) {
            result.add(node.getVal());
        }

        // 当前节点值小于最大值，需要搜索右子树
        if (node.getVal() < max) {
            rangeSearch(node.getRight(), min, max, result);
        }
    }

    /**
     * 检查树是否平衡
     * 平衡的定义是：任意节点的左右子树高度差不超过1
     * @return 如果树平衡返回true，否则返回false
     */
    public boolean isBalanced() {
        return checkBalanced(root) != -1;
    }

    /**
     * 检查子树是否平衡，并返回其高度
     * 如果不平衡返回-1
     * @param node 当前节点
     * @return 如果平衡返回高度，否则返回-1
     */
    private int checkBalanced(TreeNode node) {
        if (node == null) {
            return 0;
        }

        int leftHeight = checkBalanced(node.getLeft());
        if (leftHeight == -1) {
            return -1; // 左子树不平衡
        }

        int rightHeight = checkBalanced(node.getRight());
        if (rightHeight == -1) {
            return -1; // 右子树不平衡
        }

        // 检查当前节点是否平衡
        if (Math.abs(leftHeight - rightHeight) > 1) {
            return -1; // 当前节点不平衡
        }

        return Math.max(leftHeight, rightHeight) + 1; // 返回高度
    }

        /**
         * 获取树的高度
         * @return 树的高度，如果树为空返回0
         */
        public int height() {
            return height(root);
        }

        /**
         * 计算以指定节点为根的子树高度
         * @param node 子树根节点
         * @return 子树高度，如果节点为空返回0
         */
        private int height(TreeNode node) {
            if (node == null) {
                return 0;
            }
            int leftHeight = height(node.getLeft());
            int rightHeight = height(node.getRight());
            return Math.max(leftHeight, rightHeight) + 1;
        }

    /**
     * 中序遍历二叉搜索树
     * 对于二叉搜索树，中序遍历会产生有序序列
     * @return 中序遍历的节点值列表
     */
    public List<Integer> inorderTraversal() {
        List<Integer> result = new ArrayList<>();
        inorderTraversal(root, result);
        return result;
    }

    /**
     * 递归执行中序遍历
     * @param node 当前节点
     * @param result 结果列表
     */
    private void inorderTraversal(TreeNode node, List<Integer> result) {
        if (node == null) {
            return;
        }
        inorderTraversal(node.getLeft(), result);
        result.add(node.getVal());
        inorderTraversal(node.getRight(), result);
    }

    /**
     * 获取根节点
     * @return 根节点
     */
    public TreeNode getRoot() {
        return root;
    }

    /**
     * 设置根节点
     * @param root 新的根节点
     */
    public void setRoot(TreeNode root) {
        this.root = root;
    }

    /**
     * 主方法，用于测试二叉搜索树的功能
     * @param args 命令行参数
     */
    public static void main(String[] args) {
        BinarySearchTree bst = new BinarySearchTree();
        
        // 测试插入
        bst.insert(50);
        bst.insert(30);
        bst.insert(70);
        bst.insert(20);
        bst.insert(40);
        bst.insert(60);
        bst.insert(80);
        
        // 测试中序遍历（应该是有序的）
        System.out.println("中序遍历结果：" + bst.inorderTraversal());
        
        // 测试查找
        System.out.println("查找值40的节点: " + (bst.search(40) != null ? "存在" : "不存在"));
        System.out.println("查找值45的节点: " + (bst.search(45) != null ? "存在" : "不存在"));
        
        // 测试范围查找
        System.out.println("范围查找[30, 60]: " + bst.rangeSearch(30, 60));
        
        // 测试前驱和后继
        System.out.println("值50的前驱: " + (bst.predecessor(50) != null ? bst.predecessor(50).getVal() : "不存在"));
        System.out.println("值50的后继: " + (bst.successor(50) != null ? bst.successor(50).getVal() : "不存在"));
        
        // 测试平衡检查
        System.out.println("树是否平衡: " + bst.isBalanced());
        
        // 测试删除
        bst.delete(30);
        System.out.println("删除节点30后的中序遍历: " + bst.inorderTraversal());
        
        // 测试删除根节点
        bst.delete(50);
        System.out.println("删除根节点50后的中序遍历: " + bst.inorderTraversal());
    }
}